{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <div style='color:#ff77e0;font-weight:800;'>训练中断后接着训练的代码</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "from nes_py.wrappers import JoypadSpace\n",
    "import gym_super_mario_bros\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT\n",
    "import time\n",
    "from matplotlib import pyplot as plt\n",
    "from gym.wrappers import GrayScaleObservation\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "from stable_baselines3.common.vec_env import VecFrameStack\n",
    "import os\n",
    "from stable_baselines3 import PPO\n",
    "\n",
    "from stable_baselines3.common.results_plotter import load_results, ts2xy\n",
    "import numpy as np\n",
    "from stable_baselines3.common.callbacks import BaseCallback\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the Super Mario Bros environment and wrap it with the\n",
    "# SIMPLE_MOVEMENT action list.\n",
    "mario_env = JoypadSpace(\n",
    "    gym_super_mario_bros.make('SuperMarioBros-v0'),\n",
    "    SIMPLE_MOVEMENT,\n",
    ")\n",
    "\n",
    "# Wrap with Monitor, handing it ./monitor_log/ (created if missing).\n",
    "monitor_dir = r'./monitor_log/'\n",
    "os.makedirs(monitor_dir, exist_ok=True)\n",
    "monitored_env = Monitor(mario_env, monitor_dir)\n",
    "\n",
    "# Grayscale observations (keep_dim=True), then vectorize and stack 4 frames.\n",
    "gray_env = GrayScaleObservation(monitored_env, keep_dim=True)\n",
    "env = DummyVecEnv([lambda: gray_env])\n",
    "env = VecFrameStack(env, 4, channels_order='last')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# NOTE: adjust these settings to match your own run.\n",
    "\n",
    "tensorboard_log = r'./tensorboard_log/'\n",
    "learning_rate = 1e-7\n",
    "\n",
    "# Build a PPO agent with the CNN policy on the wrapped environment.\n",
    "model = PPO(\n",
    "    policy=\"CnnPolicy\",\n",
    "    env=env,\n",
    "    learning_rate=learning_rate,\n",
    "    tensorboard_log=tensorboard_log,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# NOTE: set this to the path of your own saved model (.zip).\n",
    "# Alternative checkpoint kept for reference:\n",
    "# save_model_dir = r'F:\\\\RL_Mario1\\\\best_model\\\\model_5000.zip'\n",
    "save_model_dir = r'F:\\\\RL_Mario\\\\best_model\\\\_1620000.zip'\n",
    "\n",
    "# Copy the saved parameters into the model built above so that training\n",
    "# continues from them (exact_match=True is the library default).\n",
    "model.set_parameters(save_model_dir, exact_match=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SaveOnBestTrainingRewardCallback(BaseCallback):\n",
    "    \"\"\"\n",
    "    Callback that writes a checkpoint of the model every ``check_freq`` calls.\n",
    "\n",
    "    Despite its name, this callback does not look at the training reward:\n",
    "    a checkpoint is saved unconditionally whenever ``n_calls`` is a multiple\n",
    "    of ``check_freq``.\n",
    "\n",
    "    :param check_freq: (int) Number of callback calls between two checkpoints.\n",
    "      Must be strictly positive.\n",
    "    :param save_model_dir: (str) Parent folder; checkpoints are written to its\n",
    "      ``best_model/`` sub-folder as ``model_<n_calls>``.\n",
    "    :param verbose: (int) When greater than 0, print ``n_calls`` each time a\n",
    "      checkpoint is saved.\n",
    "    :raises ValueError: if ``check_freq`` is not strictly positive.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, check_freq: int, save_model_dir: str, verbose: int = 1):\n",
    "        super().__init__(verbose)\n",
    "        if check_freq <= 0:\n",
    "            # A zero frequency would otherwise only surface as a\n",
    "            # ZeroDivisionError in the middle of training.\n",
    "            raise ValueError(\n",
    "                'check_freq must be a positive integer, got {!r}'.format(check_freq)\n",
    "            )\n",
    "        self.check_freq = check_freq\n",
    "        self.save_path = os.path.join(save_model_dir, 'best_model/')\n",
    "        # Not used by this callback; kept so code reading the attribute still works.\n",
    "        self.best_mean_reward = -np.inf\n",
    "\n",
    "    def _init_callback(self) -> None:\n",
    "        # Make sure the checkpoint folder exists before the first save.\n",
    "        os.makedirs(self.save_path, exist_ok=True)\n",
    "\n",
    "    def _on_step(self) -> bool:\n",
    "        if self.n_calls % self.check_freq == 0:\n",
    "            if self.verbose > 0:\n",
    "                print('self.n_calls: ',self.n_calls)\n",
    "            checkpoint_path = os.path.join(self.save_path, 'model_{}'.format(self.n_calls))\n",
    "            self.model.save(checkpoint_path)\n",
    "\n",
    "        # Always return True so that training is never aborted by this callback.\n",
    "        return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The callback writes its checkpoints under <save_model_dir>/best_model/.\n",
    "save_model_dir = r'F:\\\\RL_Mario1\\\\'\n",
    "callback1 = SaveOnBestTrainingRewardCallback(\n",
    "    check_freq=10_000,\n",
    "    save_model_dir=save_model_dir,\n",
    ")\n",
    "\n",
    "# Continue training for 3,000,000 timesteps, checkpointing through the callback.\n",
    "model.learn(total_timesteps=3_000_000, callback=callback1)\n",
    "# Optional final save, left disabled:\n",
    "# model.save(\"mario_model\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "9465aae7e0ab1403d672807d1a0963d86dbda2f584fbe3054c36cf78311c6c77"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
